import os
import numpy as np
import xarray as xr
import hvplot.xarray
import holoviews as hv
import pandas as pd
from datetime import date
import pyproj 
import scipy.interpolate 
from glob import glob
import matplotlib.pyplot as plt
import numpy.ma as ma 
from scipy.interpolate import griddata
from scipy.spatial import KDTree
from utils.read_data_utils import read_is2_data, read_book_data # This allows us to read the ICESAT2 data directly from the google storage bucket
from utils.wrangling_utils import is2_interp2d # Interpolate ICESat2 variables using CDR data 

# Ignore warnings in the notebook to improve display
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from cartopy.mpl.geoaxes import GeoAxes
GeoAxes._pcolormesh_patched = Axes.pcolormesh # Helps avoid some weird issues with the polar projection 
from utils.plotting_utils import interactiveArcticMaps, staticArcticMaps

from astropy.convolution import convolve
from astropy.convolution import Gaussian2DKernel
# Load the combined book dataset once, then pull out the two variables used
# below: monthly CDR sea ice concentration and ICESat-2 ice thickness.
book_ds = read_book_data()
is2_da = book_ds["ice_thickness"]
cdr_da = book_ds["cdr_seaice_conc_monthly"]

# Restrict both fields to a single month. `[0]` takes the first entry along
# the leading dimension of each selection (presumably the time axis left by
# the partial-date `sel` — TODO confirm dimension order).
date = "Apr 2019"  # NOTE(review): shadows the `datetime.date` import above
cdr_da, is2_da = cdr_da.sel(time=date)[0], is2_da.sel(time=date)[0]

# Extract coordinates and data as plain numpy arrays for the
# astropy/scipy processing steps below.
lats, lons = is2_da['latitude'].values, is2_da['longitude'].values
xgrid, ygrid = is2_da['xgrid'].values, is2_da['ygrid'].values

# Work on copies so later in-place edits never touch the DataArrays' data.
is2_np = is2_da.copy().values
cdr_np = cdr_da.copy().values

# Toggle for the interactive plots at the end of the script.
show_plots = True

# Perform Gaussian smoothing of the IS2 thickness field.
# Cells whose sea ice concentration is below `min_ice_conc` are set to NaN
# before smoothing (so they contribute nothing) and again afterwards (so the
# smoother cannot fill them back in).
min_ice_conc = 0.5
low_conc = cdr_np < min_ice_conc  # boolean mask, computed once and reused

kernel = Gaussian2DKernel(x_stddev=0.5)
is2_np_gaus = is2_np.copy()
is2_np_gaus[low_conc] = np.nan
# With its default nan_treatment ('interpolate'), astropy's convolve replaces
# NaN cells with a kernel-weighted average of their non-NaN neighbours; this
# is what fills gaps in the raw IS2 field.
is2_np_gaus = convolve(is2_np_gaus, kernel)
is2_np_gaus[low_conc] = np.nan

# KD Tree distance mask: discard smoothed values lying more than `distance_m`
# from any grid cell that had finite raw IS2 data.
distance_m = 100000  # in xgrid/ygrid units (presumably metres — TODO confirm)

has_data = np.isfinite(is2_np)  # cells with raw IS2 data, computed once
xS = xgrid[has_data]
yS = ygrid[has_data]
tree = KDTree(np.c_[xS, yS])

# Distance from every grid cell to its nearest data cell, back on the 2-D grid.
grid_points = np.c_[xgrid.ravel(), ygrid.ravel()]
dist, _ = tree.query(grid_points, k=1)
dist = dist.reshape(xgrid.shape)

is2_kdtree = is2_np_gaus.copy()
is2_kdtree[dist > distance_m] = np.nan

# Convert to xr.DataArray on the IS2 grid so the results can be plotted with
# the same helpers as the raw data.
gaus_da = xr.DataArray(data=is2_np_gaus, dims=is2_da.dims, coords=is2_da.coords, attrs=is2_da.attrs, name=is2_da.name)
# `dist` is the distance to the nearest cell WITH finite IS2 data (the KDTree
# is built from finite cells), so label it accordingly.
dist_da = xr.DataArray(data=dist, dims=is2_da.dims, coords=is2_da.coords, attrs={"long_name": "distance from nearest cell with data"}, name="distance")
kdtree_da = xr.DataArray(data=is2_kdtree, dims=is2_da.dims, coords=is2_da.coords, attrs=is2_da.attrs, name=is2_da.name)

if show_plots:
    # NOTE: `display` is not imported in this file; this relies on running
    # inside IPython/Jupyter, which provides it as a builtin.

    # Plot interpolation results: KDTree-masked, smoothed-only, and raw IS2.
    pl_kdtree = interactiveArcticMaps(kdtree_da, title=f"Gaussian smoothing with KDTree, {date}", vmin=0, vmax=4, frame_width=350, colorbar=False)
    pl_gaus = interactiveArcticMaps(gaus_da, title=f"Gaussian smoothing interpolation, {date}", vmin=0, vmax=4, frame_width=350, colorbar=False)
    pl_raw = interactiveArcticMaps(is2_da, title=f"IS2 raw data, {date}", vmin=0, vmax=4, frame_width=350, colorbar=True)
    display(pl_kdtree + pl_gaus + pl_raw)

    # Plot differences. `xr.ufuncs` has been removed from xarray, so NaN tests
    # use DataArray.isnull() instead of xr.ufuncs.isnan.
    # Smoothed values in cells that were NaN in the raw data (i.e. filled in).
    pl_interp_all = interactiveArcticMaps(gaus_da.where(is2_da.isnull()), vmin=0, vmax=4, title="All smoothed grid cells", frame_width=350, colorbar=False)
    # Smoothed values in cells that are NaN in the KDTree-masked field.
    pl_cells_removed_by_kdtree = interactiveArcticMaps(gaus_da.where(kdtree_da.isnull()), vmin=0, vmax=4, title="Grid cells removed by KDTree", frame_width=350, colorbar=False)
    # Raw values in cells that are NaN after smoothing.
    pl_cells_removed_by_interp = interactiveArcticMaps(is2_da.where(gaus_da.isnull()), vmin=0, vmax=4, title="Grid cells removed by Gaussian smoothing", frame_width=350, colorbar=True)
    display(pl_cells_removed_by_kdtree + pl_interp_all + pl_cells_removed_by_interp)

    # Plot distances to the nearest cell with raw IS2 data.
    pl_dist = interactiveArcticMaps(dist_da, title=f"Distance from closest cell with data, {date}", frame_width=350, cmap="coolwarm", vmax=1.5*10**6, colorbar=True)
    display(pl_dist)